/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import Util from './Util';
import BitOutputStream from './BitOutputStream';
import { FiniteStateEntropy } from './FiniteStateEntropy';
import NumberTransform from './NumberTransform'

/**
 * Finite State Entropy (FSE) encoding table.
 *
 * The table is filled either from a normalized symbol distribution ({@link initialize} /
 * {@link newInstance}) or with {@link initializeRleTable}. It is then driven through
 * {@link begin}, {@link encode} and {@link finish}.
 */
export default class FseCompressionTable {
    /** Next-state lookup, indexed by `(state >>> outputBits) + deltaFindState[symbol]` (see `encode`). */
    private readonly nextState: Int16Array;
    /**
     * Per-symbol value chosen so that `(state + deltaNumberOfBits[symbol]) >>> 16` is the number of
     * low bits of `state` that `encode` writes to the stream.
     */
    private readonly deltaNumberOfBits: Int32Array;
    /** Per-symbol offset added to the shifted state when indexing `nextState`. */
    private readonly deltaFindState: Int32Array;
    /** Bit width used by `finish` to write the final state; set by `initialize`/`initializeRleTable`. */
    private log2Size = 0;

    /**
     * Allocates storage for up to `1 << maxTableLog` states and symbols `0..maxSymbol`.
     * The table contents must be set with `initialize` or `initializeRleTable` before encoding.
     *
     * @param maxTableLog log2 of the largest table size this instance must hold
     * @param maxSymbol largest symbol value that will be encoded
     */
    constructor(maxTableLog: number, maxSymbol: number) {
        this.nextState = new Int16Array(1 << maxTableLog);
        this.deltaNumberOfBits = new Int32Array(maxSymbol + 1);
        this.deltaFindState = new Int32Array(maxSymbol + 1);
    }

    /**
     * Creates a table sized for `tableLog`/`maxSymbol` and initializes it from `normalizedCounts`.
     *
     * @throws Error if the symbol spreading step fails (see `initialize`)
     */
    public static newInstance(normalizedCounts: Int16Array, maxSymbol: number, tableLog: number): FseCompressionTable {
        const result = new FseCompressionTable(tableLog, maxSymbol);
        result.initialize(normalizedCounts, maxSymbol, tableLog);
        return result;
    }

    /**
     * Resets the table to a degenerate form with `log2Size = 0`, the first two `nextState`
     * entries set to 0, and zero deltas for a single symbol. Presumably used when the input
     * consists of one repeated symbol (RLE), as the name suggests; confirm against callers.
     *
     * @param symbols the single symbol value whose deltas are cleared
     */
    public initializeRleTable(symbols: number): void {
        this.log2Size = 0;

        this.nextState[0] = 0;
        this.nextState[1] = 0;

        this.deltaFindState[symbols] = 0;
        this.deltaNumberOfBits[symbols] = 0;
    }

    /**
     * Builds the encoding table from a normalized distribution.
     *
     * @param normalizedCounts per-symbol cell counts. `-1` marks a low-probability symbol that gets
     *        exactly one cell at the top of the table; `0` gets no cells; positive values give the
     *        number of cells.
     * @param maxSymbol largest symbol index read from `normalizedCounts`
     * @param tableLog log2 of the table size
     * @throws Error "Spread symbols failed" if spreading does not end at position 0
     */
    public initialize(normalizedCounts: Int16Array, maxSymbol: number, tableLog: number): void {
        const tableSize = 1 << tableLog;

        // Symbol owning each table cell. Stored as signed bytes, so reads below are masked with
        // 0xff to recover symbol values 128..255 instead of producing negative indices.
        const table = new Int8Array(tableSize);
        // Low-probability (-1) symbols are placed from the end of the table downward;
        // spreadSymbols then only fills cells at or below this threshold.
        let highThreshold = tableSize - 1;

        this.log2Size = tableLog;

        // cumulative[s] = first nextState slot belonging to symbol s.
        const cumulative = new Int32Array(FiniteStateEntropy.MAX_SYMBOL + 2);
        cumulative[0] = 0;
        for (let i = 1; i <= maxSymbol + 1; i++) {
            const count = normalizedCounts[i - 1];
            if (count === -1) {
                // Low probability symbol: occupies exactly one cell.
                cumulative[i] = cumulative[i - 1] + 1;
                table[highThreshold--] = NumberTransform.toByte(i - 1);
            } else {
                cumulative[i] = cumulative[i - 1] + count;
            }
        }
        cumulative[maxSymbol + 1] = tableSize + 1;

        const position = FseCompressionTable.spreadSymbols(normalizedCounts, maxSymbol, tableSize, highThreshold, table);
        if (position !== 0) {
            throw new Error("Spread symbols failed");
        }

        // Fill nextState grouped by symbol: each symbol's slots start at its cumulative offset.
        for (let i = 0; i < tableSize; i++) {
            const symbol = table[i] & 0xff;
            this.nextState[cumulative[symbol]++] = NumberTransform.toShort(tableSize + i);
        }

        // Per-symbol transforms consumed by begin/encode. Symbols with count 0 only get
        // deltaNumberOfBits; their deltaFindState entry is left unchanged.
        let total = 0;
        for (let symbol = 0; symbol <= maxSymbol; symbol++) {
            const count = normalizedCounts[symbol];
            switch (count) {
                case 0:
                    this.deltaNumberOfBits[symbol] = ((tableLog + 1) << 16) - tableSize;
                    break;
                case -1:
                case 1:
                    this.deltaNumberOfBits[symbol] = (tableLog << 16) - tableSize;
                    this.deltaFindState[symbol] = total - 1;
                    total++;
                    break;
                default: {
                    const maxBitsOut = tableLog - Util.highestBit(count - 1);
                    const minStatePlus = count << maxBitsOut;
                    this.deltaNumberOfBits[symbol] = (maxBitsOut << 16) - minStatePlus;
                    this.deltaFindState[symbol] = total - count;
                    total += count;
                    break;
                }
            }
        }
    }

    /**
     * Derives an initial state for `symbolnum` from the table alone; no bits are written.
     *
     * @returns the `nextState` entry selected for this symbol
     */
    public begin(symbolnum: number): number {
        // Rounded bit count: adding 1 << 15 before the shift rounds to the nearest value.
        const outputBits = (this.deltaNumberOfBits[symbolnum] + (1 << 15)) >>> 16;
        const base = ((outputBits << 16) - this.deltaNumberOfBits[symbolnum]) >>> outputBits;
        return this.nextState[base + this.deltaFindState[symbolnum]];
    }

    /**
     * Writes the low bits of `state` required by `symbolnum` and returns the next state.
     *
     * @param stream destination for the emitted bits
     * @param state current encoder state
     * @param symbolnum symbol being encoded
     * @returns the new encoder state
     */
    public encode(stream: BitOutputStream, state: number, symbolnum: number): number {
        // The high 16 bits of (state + delta) give the number of bits to emit.
        const outputBits = (state + this.deltaNumberOfBits[symbolnum]) >>> 16;
        stream.addBits(state, outputBits);
        return this.nextState[(state >>> outputBits) + this.deltaFindState[symbolnum]];
    }

    /**
     * Writes the final `state` using `log2Size` bits, then flushes `stream`.
     */
    public finish(stream: BitOutputStream, state: number): void {
        stream.addBits(state, this.log2Size);
        stream.flush();
    }

    /** Stride used to walk the table when spreading symbols: size/2 + size/8 + 3. */
    private static calculateStep(tableSize: number): number {
        return (tableSize >>> 1) + (tableSize >>> 3) + 3;
    }

    /**
     * Distributes each symbol's `normalizedCounters[s]` cells over `symbols[0..highThreshold]`,
     * advancing by a fixed stride modulo `tableSize` and skipping cells above `highThreshold`.
     * Non-positive counts place no cells.
     *
     * @param tableSize power-of-two table size (used as `tableSize - 1` mask)
     * @param symbols output array receiving the symbol for each visited cell
     * @returns the position after the last placement; `initialize` requires this to be 0
     */
    public static spreadSymbols(normalizedCounters: Int16Array, maxSymbolValue: number, tableSize: number, highThreshold: number, symbols: Int8Array): number {
        const mask = tableSize - 1;
        const step = FseCompressionTable.calculateStep(tableSize);

        let position = 0;
        for (let symbolnum = 0; symbolnum <= maxSymbolValue; symbolnum++) {
            for (let i = 0; i < normalizedCounters[symbolnum]; i++) {
                symbols[position] = symbolnum;
                do {
                    position = (position + step) & mask;
                } while (position > highThreshold);
            }
        }
        return position;
    }
}
